import torch
from torch import nn
import torch.nn.functional as F

from cbml_benchmark.losses.registry import LOSS


@LOSS.register('cbml_loss')
class CBMLLoss(nn.Module):
    """Pair-based embedding loss registered in ``LOSS`` as ``'cbml_loss'``.

    ``forward`` builds the cosine-similarity matrix of the L2-normalised
    batch embeddings and, for every anchor row, combines

    * a positive term over mined same-label similarities,
    * a negative term over mined different-label similarities, and
    * a spread term: the summed squared deviation of the anchor's negative
      similarities from a blended reference value.

    Values read from ``cfg.LOSSES.CBML_LOSS``:

    * ``POS_A`` / ``POS_B``: offset and scale inside the positive
      exponential ``exp(-(s - POS_A) / POS_B)``.
    * ``NEG_A`` / ``NEG_B``: offset and scale inside the negative
      exponential ``exp((s - NEG_A) / NEG_B)``.
    * ``MARGIN``: margin used when mining positives and negatives.
    * ``WEIGHT``: multiplier of the spread term.
    * ``HYPER_WEIGHT``: blend factor between the anchor row mean and the
      midpoint of (smallest positive, largest negative) similarity.
    * ``ADAPTIVE_NEG``: if truthy, negatives are mined with the margin rule;
      otherwise the 100 largest negatives are used.
    * ``TYPE``: ``'log'`` or ``'sqrt'`` select that transform of
      ``1 + sum(exp(...))``; any other value selects the weighted form
      ``1 + WEIGHT_* * sum(exp(...))``.
    * ``WEIGHT_P`` / ``WEIGHT_N``: weights used only by the weighted form.

    ``pair_wise_wdist``, ``cross_attention``, ``normalize_all`` and
    ``Sinkhorn`` implement a patch-level similarity based on a transport
    plan. ``forward`` does not call them; they are kept as public helpers.
    """

    def __init__(self, cfg):
        super().__init__()
        loss_cfg = cfg.LOSSES.CBML_LOSS
        self.pos_a = loss_cfg.POS_A
        self.pos_b = loss_cfg.POS_B
        self.neg_a = loss_cfg.NEG_A
        self.neg_b = loss_cfg.NEG_B
        self.margin = loss_cfg.MARGIN
        self.weight = loss_cfg.WEIGHT
        self.hyper_weight = loss_cfg.HYPER_WEIGHT
        self.adaptive_neg = loss_cfg.ADAPTIVE_NEG
        self.type = loss_cfg.TYPE
        self.loss_weight_p = loss_cfg.WEIGHT_P
        self.loss_weight_n = loss_cfg.WEIGHT_N

        # Settings of the patch-level similarity (see ``pair_wise_wdist``).
        self.eps = 0.05           # divisor of the cost inside exp(-cost / eps)
        self.max_iter = 100       # upper bound on Sinkhorn iterations
        self.use_uniform = False  # True: uniform marginals instead of cross attention

    def normalize_all(self, x, y, x_mean, y_mean):
        """Return the four inputs, each L2-normalised along ``dim=1``."""
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        x_mean = F.normalize(x_mean, dim=1)
        y_mean = F.normalize(y_mean, dim=1)
        return x, y, x_mean, y_mean

    def cross_attention(self, x, y, x_mean, y_mean):
        """Compute per-position weights of ``y`` and ``x`` from the other's mean.

        Args:
            x, y: tensors whose first two dims are ``(N, C)``; trailing dims
                are flattened into one position axis ``R``.
            x_mean, y_mean: ``(N, C)`` tensors.

        Returns:
            ``(u, v)``, both of shape ``(N, R)``. ``u`` is the ReLU of the dot
            product between ``x_mean`` and each position of ``y``, divided by
            its row sum (plus ``1e-5``); ``v`` is the same with the roles of
            ``x`` and ``y`` swapped.
        """
        N, C = x.shape[:2]
        x = x.view(N, C, -1)
        y = y.view(N, C, -1)

        att = F.relu(torch.einsum("nc,ncr->nr", x_mean, y)).view(N, -1)
        # The 1e-5 keeps the division finite when every response is zero.
        u = att / (att.sum(dim=1, keepdim=True) + 1e-5)
        att = F.relu(torch.einsum("nc,ncr->nr", y_mean, x)).view(N, -1)
        v = att / (att.sum(dim=1, keepdim=True) + 1e-5)
        return u, v

    def pair_wise_wdist(self, x, y):
        """Return a transport-plan-weighted similarity for each pair ``(x[b], y[b])``.

        Despite the name, the returned value is a similarity, not a distance.

        Args:
            x, y: 4-D tensors unpacked as ``(B, C, H, W)``. ``y`` is reshaped
                with the sizes taken from ``x``.

        Returns:
            Tensor of shape ``(B,)``: the sum over all position pairs of
            ``T * (local_sim + global_sim) / 2``, where ``local_sim`` is the
            cosine similarity between spatial positions, ``global_sim`` the
            cosine similarity between the spatially averaged features, and
            ``T`` the plan returned by ``Sinkhorn``. ``T`` is computed without
            gradient tracking, so gradients flow only through the similarities.
        """
        B, C, H, W = x.size()
        num_pos = H * W
        x_mean = x.mean([2, 3])
        y_mean = y.mean([2, 3])
        x = x.view(B, C, -1)
        y = y.view(B, C, -1)
        x, y, x_mean, y_mean = self.normalize_all(x, y, x_mean, y_mean)

        if self.use_uniform:
            u = torch.full((B, num_pos), 1. / num_pos, dtype=x.dtype, device=x.device)
            v = torch.full((B, num_pos), 1. / num_pos, dtype=x.dtype, device=x.device)
        else:
            u, v = self.cross_attention(x, y, x_mean, y_mean)

        sim1 = torch.einsum('bcs, bcm->bsm', x, y).contiguous()
        sim2 = torch.einsum('bc, bc->b', x_mean, y_mean).contiguous().reshape(B, 1, 1)

        # Cost is one minus the cosine similarity between positions.
        wdist = 1.0 - sim1.view(B, num_pos, num_pos)

        with torch.no_grad():
            K = torch.exp(-wdist / self.eps)
            T = self.Sinkhorn(K, u, v).detach()

        sim = (sim1 + sim2) / 2
        return torch.sum(T * sim, dim=(1, 2))

    def Sinkhorn(self, K, u, v):
        """Iteratively rescale kernel ``K`` towards marginals ``u`` and ``v``.

        Starting from all-ones vectors, alternates ``r = u / (K c)`` and
        ``c = v / (K^T r)``. Stops after ``self.max_iter`` iterations, or as
        soon as the per-sample mean absolute change of ``r`` (NaN entries
        ignored) is below ``0.1`` for every sample.

        Args:
            K: batched kernel; ``pair_wise_wdist`` passes shape ``(B, S, S)``.
            u, v: batched marginals; ``pair_wise_wdist`` passes shape ``(B, S)``.

        Returns:
            ``T`` with ``T[b, i, j] = r[b, i] * c[b, j] * K[b, i, j]``.
        """
        r = torch.ones_like(u)
        c = torch.ones_like(v)
        thresh = 1e-1
        # The transposed kernel does not change inside the loop; build it once.
        K_t = K.permute(0, 2, 1).contiguous()
        for _ in range(self.max_iter):
            r_prev = r
            r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1)
            c = v / torch.matmul(K_t, r.unsqueeze(-1)).squeeze(-1)
            err = (r - r_prev).abs().mean(dim=1)
            # Samples whose update produced NaN are excluded from the check.
            err = err[~torch.isnan(err)]

            if err.numel() == 0 or torch.max(err).item() < thresh:
                break

        T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K
        return T

    def forward(self, feats, labels):
        """Compute the loss for one batch of embeddings.

        Args:
            feats: embeddings with the batch on dim 0. ``torch.t`` is applied
                to them, so they must be at most 2-D (expected
                ``(batch, dim)``). They are L2-normalised here.
            labels: tensor with ``labels.size(0) == feats.size(0)``; it is used
                as a boolean row mask, so it is expected to be 1-D.

        Returns:
            The sum of the per-anchor losses divided by the batch size. Note
            the divisor is the full batch size, not the number of anchors
            that contributed. If no anchor contributes, a zero tensor of
            shape ``(1,)`` with ``requires_grad=True`` on the device and
            dtype of ``feats``.

        Raises:
            AssertionError: if ``feats`` and ``labels`` disagree on batch size.
        """
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"

        # NOTE: an experimental patch-level term built on ``pair_wise_wdist``
        # used to be sketched here in commented-out form; it was never active.

        feats = F.normalize(feats, p=2, dim=1)

        batch_size = feats.size(0)
        sim_mat = torch.matmul(feats, torch.t(feats))
        epsilon = 1e-5
        num_top_negatives = 100  # negatives kept when adaptive mining is off
        losses = []

        for i in range(batch_size):
            same_label = labels == labels[i]
            pos_all = sim_mat[i][same_label]
            # Drop similarities of ~1: the anchor itself and exact duplicates.
            pos_all = pos_all[pos_all < 1 - epsilon]
            neg_all = sim_mat[i][~same_label]

            if len(neg_all) < 1 or len(pos_all) < 1:
                continue

            min_pos = torch.min(pos_all)
            max_neg = torch.max(neg_all)

            # Reference value: blend of the row mean and the midpoint between
            # the smallest positive and the largest negative similarity.
            mean_ = (1. - self.hyper_weight) * torch.mean(sim_mat[i]) \
                + self.hyper_weight * (min_pos + max_neg) / 2.
            # Summed squared deviation of the negatives from that reference.
            sigma_ = torch.sum(torch.pow(neg_all - mean_, 2))

            # Keep positives that are not above the largest negative by at
            # least the margin.
            pos_pair = pos_all[pos_all - self.margin < max_neg]
            if self.adaptive_neg:
                # Keep negatives within the margin of the smallest positive.
                neg_pair = neg_all[neg_all + self.margin > min_pos]
            else:
                order = torch.argsort(neg_all)
                neg_pair = neg_all[order[-num_top_negatives:]]

            if len(neg_pair) < 1 or len(pos_pair) < 1:
                continue

            pos_exp_sum = torch.sum(torch.exp(-1. / self.pos_b * (pos_pair - self.pos_a)))
            neg_exp_sum = torch.sum(torch.exp(1. / self.neg_b * (neg_pair - self.neg_a)))

            if self.type == 'log':
                pos_loss = torch.log(1. + pos_exp_sum)
                neg_loss = torch.log(1. + neg_exp_sum)
            elif self.type == 'sqrt':
                pos_loss = torch.sqrt(1. + pos_exp_sum)
                neg_loss = torch.sqrt(1. + neg_exp_sum)
            else:
                pos_loss = 1. + self.loss_weight_p * pos_exp_sum
                neg_loss = 1. + self.loss_weight_n * neg_exp_sum

            losses.append(pos_loss + neg_loss + self.weight * sigma_)

        if not losses:
            # Create the fallback on the input's device instead of forcing
            # CUDA, so the loss also works on CPU and on non-default GPUs.
            return torch.zeros(1, dtype=feats.dtype, device=feats.device, requires_grad=True)

        return sum(losses) / batch_size
